Animal detection in the wild using Computer Vision and AI¶

Task 4. Computer vision: object detection, classification, segmentation, tracking

Team Member: Akom, Jerry

In [1]:
import os
import random
import shutil
from glob import glob
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import precision_recall_fscore_support

import tensorflow as tf

1. Prepare Dataset¶

1.1. Set Folders¶

Data Source: Animal Image Dataset (90 Different Animals)

In [2]:
# Data dir
import kagglehub
input_data_soure = kagglehub.dataset_download(
    'iamsouravbanerjee/animal-image-dataset-90-different-animals'
)
os.listdir(input_data_soure)
Out[2]:
['animals', 'name of the animals.txt']
In [3]:
# import kagglehub
input_dir = f'{input_data_soure}/animals/animals'
output_dir = './cis730_term_project/data'

train_dir = "./cis730_term_project/data/train"
val_dir = "./cis730_term_project/data/val"
test_dir = "./cis730_term_project/data/test"
checkpoint_dir = "./cis730_term_project/checkpoints"
results_dir = "./cis730_term_project/results"
image_dir = "./cis730_term_project/images"

for dir_path in [
    output_dir, train_dir, val_dir, test_dir,
    checkpoint_dir, results_dir, image_dir
]:
    os.makedirs(dir_path, exist_ok=True)

1.2. Vsualize Data Sample¶

In [4]:
class_folders = [f for f in Path(input_dir).iterdir() if f.is_dir()]

sampled_classes = random.sample(class_folders, min(10, len(class_folders)))
plt.figure(figsize=(15, 5))

for idx, class_path in enumerate(sampled_classes):
    images = list(class_path.glob("*.jpg"))
    if not images:
        continue
    img_path = random.choice(images)
    img = Image.open(img_path)
    plt.subplot(2, 5, idx + 1)
    plt.imshow(img)
    plt.title(class_path.name)
    plt.axis("off")

plt.tight_layout()
plt.savefig(f"{image_dir}/data_sample.png")
plt.show()
No description has been provided for this image

1.3. Prepare Train / Validation / Test Sets¶

In [5]:
train_val_test_split = [0.7, 0.15, 0.15]

for class_path in class_folders:
    class_name = class_path.name
    images = list(class_path.glob("*.jpg"))
    if len(images) < 3:
        continue

    # First split into train and temp (val+test)
    train_imgs, temp_imgs = train_test_split(
        images, train_size=train_val_test_split[0], random_state=42
    )

    # Then split temp into val and test
    val_imgs, test_imgs = train_test_split(
        temp_imgs,
        train_size=train_val_test_split[1] / (train_val_test_split[1] + train_val_test_split[2]),
        random_state=42
    )

    # Now write to folders
    for subset_name, subset_imgs in zip(["train", "val", "test"], [train_imgs, val_imgs, test_imgs]):
        dest_class_dir = Path(output_dir) / subset_name / class_name
        for img_path in subset_imgs:
            try:
                dest_class_dir.mkdir(parents=True, exist_ok=True)
                dest_file_path = dest_class_dir / img_path.name
                shutil.copy(str(img_path), str(dest_file_path))
            except FileNotFoundError:
                print(f"File not found: {img_path}")
            except Exception as e:
                print(f"Error copying {img_path}: {e}")

2. Modelling (No Data Augmentation)¶

2.1. Modeling Functions¶

2.1.1. Data Preparation Functions¶

In [6]:
# Image preparation function
def load_image_dataset(
    data_dir,
    image_size=(224, 224),
    batch_size=32,
    label_mode="categorical",
    shuffle=True,
    seed=123,
    model_name="mobilenet",
    augment=False
):
    if model_name == "mobilenet":
        if not augment:
            generator = tf.keras.preprocessing.image.ImageDataGenerator(
                preprocessing_function = tf.keras.applications.mobilenet.preprocess_input,
            )
        else:
            print(f'Performing data augmentations for {model_name} model')
            print()
            generator = tf.keras.preprocessing.image.ImageDataGenerator(
                preprocessing_function = tf.keras.applications.mobilenet.preprocess_input,
                rotation_range=20,
                width_shift_range=0.1,
                height_shift_range=0.1,
                zoom_range=0.1,
                horizontal_flip=True
            )

        data_tmp = generator.flow_from_directory(
            data_dir,
            target_size=image_size,
            batch_size=batch_size,
            class_mode=label_mode,
            shuffle=shuffle,
            seed=seed
        )

    elif model_name == "efficientnet":
        if not augment:
            generator = tf.keras.preprocessing.image.ImageDataGenerator(
                preprocessing_function = tf.keras.applications.efficientnet.preprocess_input,
            )
        else:
            print(f'Performing data augmentations for {model_name} model')
            print()
            generator = tf.keras.preprocessing.image.ImageDataGenerator(
                preprocessing_function = tf.keras.applications.efficientnet.preprocess_input,
                rotation_range=20,
                width_shift_range=0.1,
                height_shift_range=0.1,
                zoom_range=0.1,
                horizontal_flip=True
            )

        data_tmp = generator.flow_from_directory(
            data_dir,
            target_size=image_size,
            batch_size=batch_size,
            class_mode=label_mode,
            shuffle=shuffle,
            seed=seed
        )

    elif model_name == "resnet":
        if not augment:
            generator = tf.keras.preprocessing.image.ImageDataGenerator(
                preprocessing_function = tf.keras.applications.resnet50.preprocess_input,
            )
        else:
            print(f'Performing data augmentations for {model_name} model')
            print()
            generator = tf.keras.preprocessing.image.ImageDataGenerator(
                preprocessing_function = tf.keras.applications.resnet50.preprocess_input,
                rotation_range=20,
                width_shift_range=0.1,
                height_shift_range=0.1,
                zoom_range=0.1,
                horizontal_flip=True
            )

        data_tmp = generator.flow_from_directory(
            data_dir,
            target_size=image_size,
            batch_size=batch_size,
            class_mode=label_mode,
            shuffle=shuffle,
            seed=seed
        )

    return data_tmp

2.1.2. Model Build / Train functions¶

In [7]:
# Model Compile Parameters
compile_params = {
    'optimizer': 'adam',
    'loss': 'categorical_crossentropy',
    'metrics': ['accuracy']
}

# Build Models - 3 Models Considered: MobileNetV2, EfficientNetB3, ResNet50
def build_model(
    model_name="mobilenet",
    input_shape=(224, 224, 3),
    num_classes=10,
    compile_params=compile_params
):
    if model_name == "mobilenet":
        base_model = tf.keras.applications.MobileNetV2(
            input_shape=input_shape,
            include_top=False,
            weights="imagenet"
        )

    elif model_name == "efficientnet":
        base_model = tf.keras.applications.EfficientNetB3(
            input_shape=(300, 300, 3),
            include_top=False,
            weights="imagenet"
        )

    elif model_name == "resnet":
        base_model = tf.keras.applications.ResNet50(
            input_shape=input_shape,
            include_top=False,
            weights="imagenet"
        )

    else:
        raise ValueError("Choose one of: 'mobilenet', 'efficientnet', or 'resnet'")

    base_model.trainable = False  # Freeze base

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(
            num_classes,
            activation="softmax"
        )
    ])

    model.compile(**compile_params)

    return model

# Callbacks
def get_early_stopping_callback(
    monitor="val_loss",
    patience=3,
	restore_best_weights=True,
    verbose=1
):
    return tf.keras.callbacks.EarlyStopping(
        monitor=monitor,
        patience=patience,
        restore_best_weights=restore_best_weights,
        verbose=verbose,
    )

def get_check_point_callback(
	checkpoint_path="MobileNetV2CheckPoint.keras",
    monitor="val_loss",
    save_best_only=True,
    verbose=1
):
    return tf.keras.callbacks.ModelCheckpoint(
		checkpoint_path,
        monitor=monitor,
        save_best_only=save_best_only,
        verbose=verbose,
    )

# Predict on test data
def get_predictions(dataset, model):
    y_true = dataset.classes
    y_pred = np.argmax(model.predict(dataset), axis = 1)
    return y_true, y_pred

2.1.3. Visualization functions¶

In [8]:
# To visualize training loss/accuracy
def plot_training_loss_accuracy(history, model_name="mobilenet", output_dir=image_dir):
    if model_name == "mobilenet":
        title_loss = "Loss MobileNetV2"
        title_acc = "Accuracy MobileNetV2"
    elif model_name == "efficientnet":
        title_loss = "Loss EfficientNetB3"
        title_acc = "Accuracy EfficientNetB0"
    elif model_name == "resnet":
        title_loss = "Loss ResNet50"
        title_acc = "Accuracy ResNet50"
    else:
        raise ValueError("Choose one of: 'mobilenet', 'efficientnet', or 'resnet'")

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], label='Train')
    plt.plot(history.history['val_loss'], label='Val')
    plt.title(title_loss)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(history.history['accuracy'], label='Train')
    plt.plot(history.history['val_accuracy'], label='Val')
    plt.title(title_acc)
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid()

    plt.savefig(f"{output_dir}/{model_name}_accuracy_loss.png")
    plt.show()

def get_best_worst(y_true, y_pred, class_names):
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # Calculate score
    precision, recall, f1, _ = \
        precision_recall_fscore_support(
            y_true,
            y_pred,
            labels=range(len(list(class_names)))
    )

    # Sort classes by F1 score
    sorted_indices = np.argsort(f1)
    worst_5 = sorted_indices[:5]
    best_5 = sorted_indices[-5:]


    print("Best 5 classes (F1 Score):")
    print([class_names[i] for i in best_5])
    print()
    print("Worst 5 classes (F1 Score):")
    print([class_names[i] for i in worst_5])

    return best_5, worst_5

def plot_confusion_for_classes(
    y_true, y_pred, selected_best, selected_worst, class_names, model_name,
    output_dir=image_dir
):
    # Filter the relevant rows and columns
    cm = confusion_matrix(y_true, y_pred)
    sub_cm_best = cm[np.ix_(selected_best, selected_best)]
    labels_best = [class_names[i] for i in selected_best]

    sub_cm_worst = cm[np.ix_(selected_worst, selected_worst)]
    labels_worst = [class_names[i] for i in selected_worst]

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    # Display confusion matrix 1
    disp1 = ConfusionMatrixDisplay(
        confusion_matrix=sub_cm_best, display_labels = labels_best
    )
    disp1.plot(ax=axes[0], colorbar=False)
    axes[0].set_title("Best 5 Predicted Classes (F1 Score)")
    axes[0].tick_params(axis='x', rotation=45)

    # Display confusion matrix 2
    disp2 = ConfusionMatrixDisplay(
        confusion_matrix=sub_cm_worst, display_labels = labels_worst
    )
    disp2.plot(ax=axes[1], colorbar=False)
    axes[1].set_title("Worst 5 Predicted Classes (F1 Score)")
    axes[1].tick_params(axis='x', rotation=45)

    fig.tight_layout()
    fig.savefig(
        f"{output_dir}/{model_name}_confusion_matrix.png", bbox_inches='tight'
    )
    plt.show()

def show_wrong_predictions(
    y_true, y_pred, wrong_indices, test_dir, model_name, num_images=10,
    output_dir=image_dir
):
    test_paths = list(Path(test_dir).rglob("*.jpg"))
    test_paths.sort()
    plt.figure(figsize=(15, 5))
    for i, idx in enumerate(wrong_indices[:num_images]):
        image_path = str(test_paths[idx])
        img = tf.keras.preprocessing.image.load_img(
            image_path, target_size=(224, 224)
        )
        plt.subplot(2, 5, i + 1)
        plt.imshow(img)
        plt.axis("off")
        pred_label = class_names[y_pred[idx]]
        true_label = class_names[y_true[idx]]
        plt.title(f"Pred: {pred_label}\nTrue: {true_label}", color="red" if pred_label != true_label else "green")
    plt.suptitle("Wrong Predictions", fontsize=20)
    plt.tight_layout()
    plt.savefig(f"{output_dir}/{model_name}_wrong_predictiond.png")
    plt.show()

2.4. Model 1 - MobileNetV2¶

2.4.0. Load Datasets¶

In [9]:
# Load datasets
train_mnv2 = load_image_dataset(train_dir)
val_mnv2 = load_image_dataset(val_dir)
test_mnv2 = load_image_dataset(test_dir, shuffle=False)
Found 3780 images belonging to 90 classes.
Found 810 images belonging to 90 classes.
Found 810 images belonging to 90 classes.

2.4.1. Build / Train MobileNetV2 Model¶

Model Reference: MobileNetV2

Source: https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5

In [10]:
# Load model
model_mnv2 = build_model(
    "mobilenet",
    num_classes=train_mnv2.num_classes,
    compile_params=compile_params
)

# Train model
early_stop = get_early_stopping_callback()
model_check_point = get_check_point_callback(
    checkpoint_path=f"{checkpoint_dir}/MobileNetV2CheckPoint.keras"
)
history_mnv2 = model_mnv2.fit(
    train_mnv2,
    validation_data=val_mnv2,
    epochs=10,
    callbacks=[early_stop, model_check_point]
)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9406464/9406464 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
/usr/local/lib/python3.11/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
  self._warn_if_super_not_called()
Epoch 1/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 204ms/step - accuracy: 0.3334 - loss: 3.1947
Epoch 1: val_loss improved from inf to 0.80063, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 48s 303ms/step - accuracy: 0.3353 - loss: 3.1853 - val_accuracy: 0.8160 - val_loss: 0.8006
Epoch 2/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 168ms/step - accuracy: 0.9114 - loss: 0.4832
Epoch 2: val_loss improved from 0.80063 to 0.58417, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 25s 214ms/step - accuracy: 0.9114 - loss: 0.4829 - val_accuracy: 0.8617 - val_loss: 0.5842
Epoch 3/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 172ms/step - accuracy: 0.9731 - loss: 0.2392
Epoch 3: val_loss improved from 0.58417 to 0.50797, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 26s 217ms/step - accuracy: 0.9731 - loss: 0.2392 - val_accuracy: 0.8605 - val_loss: 0.5080
Epoch 4/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 180ms/step - accuracy: 0.9897 - loss: 0.1441
Epoch 4: val_loss improved from 0.50797 to 0.46614, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 27s 227ms/step - accuracy: 0.9897 - loss: 0.1440 - val_accuracy: 0.8716 - val_loss: 0.4661
Epoch 5/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 172ms/step - accuracy: 0.9959 - loss: 0.0947
Epoch 5: val_loss improved from 0.46614 to 0.44359, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 26s 217ms/step - accuracy: 0.9959 - loss: 0.0947 - val_accuracy: 0.8790 - val_loss: 0.4436
Epoch 6/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 173ms/step - accuracy: 0.9990 - loss: 0.0649
Epoch 6: val_loss improved from 0.44359 to 0.42028, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 26s 216ms/step - accuracy: 0.9990 - loss: 0.0649 - val_accuracy: 0.8815 - val_loss: 0.4203
Epoch 7/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 178ms/step - accuracy: 0.9990 - loss: 0.0485
Epoch 7: val_loss improved from 0.42028 to 0.41121, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 27s 227ms/step - accuracy: 0.9990 - loss: 0.0486 - val_accuracy: 0.8840 - val_loss: 0.4112
Epoch 8/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 170ms/step - accuracy: 0.9990 - loss: 0.0398
Epoch 8: val_loss improved from 0.41121 to 0.41068, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 26s 218ms/step - accuracy: 0.9990 - loss: 0.0398 - val_accuracy: 0.8889 - val_loss: 0.4107
Epoch 9/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 170ms/step - accuracy: 0.9997 - loss: 0.0330
Epoch 9: val_loss improved from 0.41068 to 0.40021, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 26s 218ms/step - accuracy: 0.9997 - loss: 0.0330 - val_accuracy: 0.8901 - val_loss: 0.4002
Epoch 10/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 173ms/step - accuracy: 0.9994 - loss: 0.0265
Epoch 10: val_loss improved from 0.40021 to 0.39781, saving model to ./cis730_term_project/checkpoints/MobileNetV2CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 26s 219ms/step - accuracy: 0.9994 - loss: 0.0265 - val_accuracy: 0.8864 - val_loss: 0.3978
Restoring model weights from the end of the best epoch: 10.
In [11]:
# Plot loss/accuracy
plot_training_loss_accuracy(history_mnv2)
No description has been provided for this image

2.4.2. Fine Tune MobileNetV2 Model¶

In [12]:
model_mnv2.trainable = True
model_mnv2.summary(show_trainable=True)

model_mnv2.compile(**compile_params)

history_mnv2 = model_mnv2.fit(
    train_mnv2,
    validation_data=val_mnv2,
    epochs=10,
    callbacks=[early_stop, model_check_point]
)
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃ Layer (type)                ┃ Output Shape          ┃    Param # ┃ Trai… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│ mobilenetv2_1.00_224        │ (None, 7, 7, 1280)    │  2,257,984 │   N   │
│ (Functional)                │                       │            │       │
├─────────────────────────────┼───────────────────────┼────────────┼───────┤
│ global_average_pooling2d    │ (None, 1280)          │          0 │   -   │
│ (GlobalAveragePooling2D)    │                       │            │       │
├─────────────────────────────┼───────────────────────┼────────────┼───────┤
│ dense (Dense)               │ (None, 90)            │    115,290 │   Y   │
└─────────────────────────────┴───────────────────────┴────────────┴───────┘
 Total params: 2,603,856 (9.93 MB)
 Trainable params: 115,290 (450.35 KB)
 Non-trainable params: 2,257,984 (8.61 MB)
 Optimizer params: 230,582 (900.71 KB)
Epoch 1/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 183ms/step - accuracy: 0.9980 - loss: 0.0316
Epoch 1: val_loss did not improve from 0.39781
119/119 ━━━━━━━━━━━━━━━━━━━━ 38s 267ms/step - accuracy: 0.9980 - loss: 0.0316 - val_accuracy: 0.8864 - val_loss: 0.4187
Epoch 2/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 175ms/step - accuracy: 0.9997 - loss: 0.0136
Epoch 2: val_loss did not improve from 0.39781
119/119 ━━━━━━━━━━━━━━━━━━━━ 25s 214ms/step - accuracy: 0.9997 - loss: 0.0136 - val_accuracy: 0.8889 - val_loss: 0.4052
Epoch 3/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 173ms/step - accuracy: 0.9997 - loss: 0.0082
Epoch 3: val_loss did not improve from 0.39781
119/119 ━━━━━━━━━━━━━━━━━━━━ 25s 211ms/step - accuracy: 0.9997 - loss: 0.0082 - val_accuracy: 0.8938 - val_loss: 0.4105
Epoch 4/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 171ms/step - accuracy: 0.9999 - loss: 0.0061
Epoch 4: val_loss did not improve from 0.39781
119/119 ━━━━━━━━━━━━━━━━━━━━ 25s 209ms/step - accuracy: 0.9999 - loss: 0.0061 - val_accuracy: 0.8877 - val_loss: 0.3982
Epoch 5/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 173ms/step - accuracy: 0.9999 - loss: 0.0043
Epoch 5: val_loss did not improve from 0.39781
119/119 ━━━━━━━━━━━━━━━━━━━━ 25s 211ms/step - accuracy: 0.9999 - loss: 0.0043 - val_accuracy: 0.8889 - val_loss: 0.3998
Epoch 6/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 171ms/step - accuracy: 0.9997 - loss: 0.0044
Epoch 6: val_loss did not improve from 0.39781
119/119 ━━━━━━━━━━━━━━━━━━━━ 25s 210ms/step - accuracy: 0.9997 - loss: 0.0044 - val_accuracy: 0.8914 - val_loss: 0.4006
Epoch 7/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 170ms/step - accuracy: 0.9991 - loss: 0.0047
Epoch 7: val_loss did not improve from 0.39781
119/119 ━━━━━━━━━━━━━━━━━━━━ 25s 208ms/step - accuracy: 0.9991 - loss: 0.0047 - val_accuracy: 0.8877 - val_loss: 0.4043
Epoch 7: early stopping
Restoring model weights from the end of the best epoch: 4.

2.4.3. Accuracy Report¶

Get Predictions¶
In [13]:
y_true_mnv2, y_pred_mnv2 = get_predictions(test_mnv2, model_mnv2)
26/26 ━━━━━━━━━━━━━━━━━━━━ 11s 276ms/step
Classification Report¶
In [14]:
print(classification_report(y_true_mnv2, y_pred_mnv2, target_names=test_mnv2.class_indices.keys()))
                precision    recall  f1-score   support

      antelope       1.00      0.78      0.88         9
        badger       0.90      1.00      0.95         9
           bat       0.71      0.56      0.62         9
          bear       1.00      0.89      0.94         9
           bee       0.90      1.00      0.95         9
        beetle       1.00      0.89      0.94         9
         bison       1.00      1.00      1.00         9
          boar       1.00      1.00      1.00         9
     butterfly       0.90      1.00      0.95         9
           cat       0.90      1.00      0.95         9
   caterpillar       0.86      0.67      0.75         9
    chimpanzee       0.90      1.00      0.95         9
     cockroach       0.78      0.78      0.78         9
           cow       0.88      0.78      0.82         9
        coyote       0.73      0.89      0.80         9
          crab       1.00      1.00      1.00         9
          crow       0.80      0.89      0.84         9
          deer       0.62      0.89      0.73         9
           dog       0.80      0.89      0.84         9
       dolphin       0.80      0.89      0.84         9
        donkey       0.64      0.78      0.70         9
     dragonfly       1.00      0.89      0.94         9
          duck       1.00      0.89      0.94         9
         eagle       0.89      0.89      0.89         9
      elephant       1.00      0.89      0.94         9
      flamingo       1.00      1.00      1.00         9
           fly       0.89      0.89      0.89         9
           fox       1.00      0.78      0.88         9
          goat       0.78      0.78      0.78         9
      goldfish       1.00      1.00      1.00         9
         goose       0.82      1.00      0.90         9
       gorilla       1.00      0.78      0.88         9
   grasshopper       0.90      1.00      0.95         9
       hamster       0.82      1.00      0.90         9
          hare       0.86      0.67      0.75         9
      hedgehog       0.90      1.00      0.95         9
  hippopotamus       0.75      1.00      0.86         9
      hornbill       1.00      0.89      0.94         9
         horse       0.75      0.67      0.71         9
   hummingbird       1.00      1.00      1.00         9
         hyena       1.00      0.89      0.94         9
     jellyfish       0.89      0.89      0.89         9
      kangaroo       0.89      0.89      0.89         9
         koala       0.82      1.00      0.90         9
      ladybugs       1.00      0.89      0.94         9
       leopard       1.00      0.89      0.94         9
          lion       1.00      0.89      0.94         9
        lizard       0.78      0.78      0.78         9
       lobster       1.00      1.00      1.00         9
      mosquito       0.75      0.67      0.71         9
          moth       0.67      0.89      0.76         9
         mouse       0.67      0.44      0.53         9
       octopus       0.57      0.89      0.70         9
         okapi       1.00      0.78      0.88         9
     orangutan       1.00      1.00      1.00         9
         otter       0.73      0.89      0.80         9
           owl       0.75      0.67      0.71         9
            ox       0.67      0.67      0.67         9
        oyster       1.00      0.78      0.88         9
         panda       1.00      1.00      1.00         9
        parrot       1.00      1.00      1.00         9
pelecaniformes       1.00      1.00      1.00         9
       penguin       1.00      1.00      1.00         9
           pig       1.00      1.00      1.00         9
        pigeon       0.88      0.78      0.82         9
     porcupine       0.90      1.00      0.95         9
        possum       0.75      0.67      0.71         9
       raccoon       1.00      0.89      0.94         9
           rat       0.71      0.56      0.62         9
      reindeer       0.80      0.89      0.84         9
    rhinoceros       0.78      0.78      0.78         9
     sandpiper       1.00      1.00      1.00         9
      seahorse       0.90      1.00      0.95         9
          seal       0.88      0.78      0.82         9
         shark       1.00      1.00      1.00         9
         sheep       0.78      0.78      0.78         9
         snake       0.73      0.89      0.80         9
       sparrow       1.00      1.00      1.00         9
         squid       0.88      0.78      0.82         9
      squirrel       0.67      0.67      0.67         9
      starfish       0.89      0.89      0.89         9
          swan       1.00      1.00      1.00         9
         tiger       1.00      0.89      0.94         9
        turkey       1.00      1.00      1.00         9
        turtle       1.00      0.89      0.94         9
         whale       1.00      0.78      0.88         9
          wolf       1.00      0.89      0.94         9
        wombat       0.82      1.00      0.90         9
    woodpecker       0.90      1.00      0.95         9
         zebra       1.00      1.00      1.00         9

      accuracy                           0.88       810
     macro avg       0.89      0.88      0.88       810
  weighted avg       0.89      0.88      0.88       810

Best / Worst Predictions¶
In [15]:
class_names =  list(test_mnv2.class_indices.keys())

best_5_mnv2, worst_5_mnv2 = get_best_worst(
    y_true_mnv2, y_pred_mnv2, class_names
)
Best 5 classes (F1 Score):
['shark', 'sandpiper', 'turkey', 'swan', 'zebra']

Worst 5 classes (F1 Score):
['mouse', 'bat', 'rat', 'ox', 'squirrel']
Confusion Matrix¶
In [16]:
plot_confusion_for_classes(
    y_true_mnv2, y_pred_mnv2, best_5_mnv2, worst_5_mnv2,
    class_names, model_name = "mobilenet"
)
No description has been provided for this image
Wrong Predictions¶
In [17]:
# Get wrong predictions
wrong_indices_mnv2 = np.where(y_true_mnv2 != y_pred_mnv2)[0]
print(f"Found {len(wrong_indices_mnv2)} misclassified images.")

# plot
show_wrong_predictions(y_true_mnv2, y_pred_mnv2, wrong_indices_mnv2, test_dir, model_name = "mobilenet")
Found 98 misclassified images.
No description has been provided for this image
Save Predictions¶
In [18]:
# Save Predictions
df_preds_mnv2 = pd.DataFrame({
    'filename': [str(os.path.basename(p)) for p in test_mnv2.filenames],
    'true_label': [class_names[i] for i in y_true_mnv2],
    'predicted_label': [class_names[i] for i in y_pred_mnv2],
    'correct': (y_true_mnv2 == y_pred_mnv2)
})

df_preds_mnv2.to_csv(f"{results_dir}/predictions_MobileNetV2.csv", index=False)

2.5. Model 2 - EfficientNetB3¶

2.5.0. Load Datasets¶

In [19]:
# Load datasets
train_enB3 = load_image_dataset(train_dir, image_size=(300, 300), model_name = "efficientnet")
val_enB3 = load_image_dataset(val_dir, image_size=(300, 300), model_name = "efficientnet")
test_enB3 = load_image_dataset(test_dir, image_size=(300, 300), shuffle=False, model_name = "efficientnet")
Found 3780 images belonging to 90 classes.
Found 810 images belonging to 90 classes.
Found 810 images belonging to 90 classes.

2.5.1. Build / Train EfficientNetB3 Model¶

Model Reference: [https://keras.io/api/applications/efficientnet/#efficientnetb3-function)

Source: https://storage.googleapis.com/keras-applications/efficientnetb3_notop.h5

In [20]:
compile_params.update(optimizer=tf.keras.optimizers.Adam(0.0005))

# Load model
model_enB3 = build_model(
    "efficientnet",
    num_classes=train_enB3.num_classes,
    compile_params=compile_params
)

early_stop = get_early_stopping_callback()
model_check_point = get_check_point_callback(
    checkpoint_path=f"{checkpoint_dir}/EfficientNetB3CheckPoint.keras"
)

# Train model
history_enB3 = model_enB3.fit(
    train_enB3,
    validation_data=val_enB3,
    epochs=10,
    callbacks=[early_stop, model_check_point]
)
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb3_notop.h5
43941136/43941136 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
/usr/local/lib/python3.11/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
  self._warn_if_super_not_called()
Epoch 1/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 370ms/step - accuracy: 0.3454 - loss: 3.7448
Epoch 1: val_loss improved from inf to 1.63013, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 105s 575ms/step - accuracy: 0.3475 - loss: 3.7384 - val_accuracy: 0.8654 - val_loss: 1.6301
Epoch 2/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 240ms/step - accuracy: 0.8922 - loss: 1.2590
Epoch 2: val_loss improved from 1.63013 to 0.73234, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 36s 303ms/step - accuracy: 0.8923 - loss: 1.2571 - val_accuracy: 0.9099 - val_loss: 0.7323
Epoch 3/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 232ms/step - accuracy: 0.9326 - loss: 0.5671
Epoch 3: val_loss improved from 0.73234 to 0.49366, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 35s 297ms/step - accuracy: 0.9326 - loss: 0.5667 - val_accuracy: 0.9210 - val_loss: 0.4937
Epoch 4/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 233ms/step - accuracy: 0.9522 - loss: 0.3680
Epoch 4: val_loss improved from 0.49366 to 0.38916, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 35s 297ms/step - accuracy: 0.9522 - loss: 0.3679 - val_accuracy: 0.9272 - val_loss: 0.3892
Epoch 5/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 226ms/step - accuracy: 0.9620 - loss: 0.2756
Epoch 5: val_loss improved from 0.38916 to 0.33032, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 38s 322ms/step - accuracy: 0.9620 - loss: 0.2755 - val_accuracy: 0.9370 - val_loss: 0.3303
Epoch 6/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 236ms/step - accuracy: 0.9737 - loss: 0.2094
Epoch 6: val_loss improved from 0.33032 to 0.29512, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 36s 302ms/step - accuracy: 0.9737 - loss: 0.2094 - val_accuracy: 0.9383 - val_loss: 0.2951
Epoch 7/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 233ms/step - accuracy: 0.9859 - loss: 0.1652
Epoch 7: val_loss improved from 0.29512 to 0.26861, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 36s 299ms/step - accuracy: 0.9859 - loss: 0.1653 - val_accuracy: 0.9444 - val_loss: 0.2686
Epoch 8/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 229ms/step - accuracy: 0.9826 - loss: 0.1436
Epoch 8: val_loss improved from 0.26861 to 0.25104, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 39s 325ms/step - accuracy: 0.9826 - loss: 0.1435 - val_accuracy: 0.9407 - val_loss: 0.2510
Epoch 9/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 236ms/step - accuracy: 0.9893 - loss: 0.1136
Epoch 9: val_loss improved from 0.25104 to 0.23854, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 36s 301ms/step - accuracy: 0.9893 - loss: 0.1136 - val_accuracy: 0.9457 - val_loss: 0.2385
Epoch 10/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 237ms/step - accuracy: 0.9917 - loss: 0.1023
Epoch 10: val_loss improved from 0.23854 to 0.22661, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 36s 302ms/step - accuracy: 0.9917 - loss: 0.1023 - val_accuracy: 0.9457 - val_loss: 0.2266
Restoring model weights from the end of the best epoch: 10.
In [21]:
# Plot loss/accuracy
plot_training_loss_accuracy(history_enB3, model_name="efficientnet")
No description has been provided for this image

2.5.2. Fine Tune EfficientNetB3 Model¶

In [22]:
model_enB3.trainable = True
model_enB3.summary(show_trainable=True)

model_enB3.compile(**compile_params)

history_enB3 = model_enB3.fit(
    train_enB3,
    validation_data=val_enB3,
    epochs=10,
    callbacks=[early_stop, model_check_point]
)
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃ Layer (type)                ┃ Output Shape          ┃    Param # ┃ Trai… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│ efficientnetb3 (Functional) │ (None, 10, 10, 1536)  │ 10,783,535 │   N   │
├─────────────────────────────┼───────────────────────┼────────────┼───────┤
│ global_average_pooling2d_1  │ (None, 1536)          │          0 │   -   │
│ (GlobalAveragePooling2D)    │                       │            │       │
├─────────────────────────────┼───────────────────────┼────────────┼───────┤
│ dense_1 (Dense)             │ (None, 90)            │    138,330 │   Y   │
└─────────────────────────────┴───────────────────────┴────────────┴───────┘
 Total params: 11,198,527 (42.72 MB)
 Trainable params: 138,330 (540.35 KB)
 Non-trainable params: 10,783,535 (41.14 MB)
 Optimizer params: 276,662 (1.06 MB)
Epoch 1/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 355ms/step - accuracy: 0.9922 - loss: 0.0899
Epoch 1: val_loss improved from 0.22661 to 0.21912, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 99s 573ms/step - accuracy: 0.9922 - loss: 0.0899 - val_accuracy: 0.9457 - val_loss: 0.2191
Epoch 2/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 228ms/step - accuracy: 0.9930 - loss: 0.0785
Epoch 2: val_loss improved from 0.21912 to 0.21307, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 35s 293ms/step - accuracy: 0.9930 - loss: 0.0785 - val_accuracy: 0.9420 - val_loss: 0.2131
Epoch 3/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 233ms/step - accuracy: 0.9957 - loss: 0.0682
Epoch 3: val_loss improved from 0.21307 to 0.20707, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 39s 329ms/step - accuracy: 0.9957 - loss: 0.0682 - val_accuracy: 0.9420 - val_loss: 0.2071
Epoch 4/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 230ms/step - accuracy: 0.9955 - loss: 0.0611
Epoch 4: val_loss improved from 0.20707 to 0.20177, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 35s 297ms/step - accuracy: 0.9955 - loss: 0.0611 - val_accuracy: 0.9432 - val_loss: 0.2018
Epoch 5/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 231ms/step - accuracy: 0.9969 - loss: 0.0560
Epoch 5: val_loss improved from 0.20177 to 0.19720, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 39s 327ms/step - accuracy: 0.9969 - loss: 0.0560 - val_accuracy: 0.9457 - val_loss: 0.1972
Epoch 6/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 244ms/step - accuracy: 0.9972 - loss: 0.0479
Epoch 6: val_loss improved from 0.19720 to 0.19328, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 37s 311ms/step - accuracy: 0.9972 - loss: 0.0479 - val_accuracy: 0.9432 - val_loss: 0.1933
Epoch 7/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 228ms/step - accuracy: 0.9983 - loss: 0.0442
Epoch 7: val_loss improved from 0.19328 to 0.18986, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 35s 293ms/step - accuracy: 0.9983 - loss: 0.0442 - val_accuracy: 0.9506 - val_loss: 0.1899
Epoch 8/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 233ms/step - accuracy: 0.9976 - loss: 0.0400
Epoch 8: val_loss improved from 0.18986 to 0.18791, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 36s 301ms/step - accuracy: 0.9976 - loss: 0.0400 - val_accuracy: 0.9457 - val_loss: 0.1879
Epoch 9/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 223ms/step - accuracy: 0.9986 - loss: 0.0365
Epoch 9: val_loss improved from 0.18791 to 0.18633, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 34s 285ms/step - accuracy: 0.9986 - loss: 0.0365 - val_accuracy: 0.9444 - val_loss: 0.1863
Epoch 10/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 244ms/step - accuracy: 0.9960 - loss: 0.0359
Epoch 10: val_loss improved from 0.18633 to 0.18304, saving model to ./cis730_term_project/checkpoints/EfficientNetB3CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 44s 312ms/step - accuracy: 0.9960 - loss: 0.0359 - val_accuracy: 0.9457 - val_loss: 0.1830
Restoring model weights from the end of the best epoch: 10.

2.5.3. Accuracy Report¶

Get Predictions¶
In [23]:
# Get Predictions
y_true_enB3, y_pred_enB3 = get_predictions(test_enB3, model_enB3)
26/26 ━━━━━━━━━━━━━━━━━━━━ 24s 587ms/step
Classification Report¶
In [24]:
print(classification_report(
    y_true_enB3, y_pred_enB3, target_names=test_enB3.class_indices.keys()
))
                precision    recall  f1-score   support

      antelope       1.00      0.89      0.94         9
        badger       1.00      1.00      1.00         9
           bat       0.88      0.78      0.82         9
          bear       1.00      1.00      1.00         9
           bee       1.00      1.00      1.00         9
        beetle       1.00      0.89      0.94         9
         bison       1.00      1.00      1.00         9
          boar       1.00      1.00      1.00         9
     butterfly       0.82      1.00      0.90         9
           cat       0.82      1.00      0.90         9
   caterpillar       1.00      0.78      0.88         9
    chimpanzee       1.00      1.00      1.00         9
     cockroach       1.00      1.00      1.00         9
           cow       1.00      0.78      0.88         9
        coyote       0.89      0.89      0.89         9
          crab       1.00      1.00      1.00         9
          crow       1.00      1.00      1.00         9
          deer       0.90      1.00      0.95         9
           dog       0.90      1.00      0.95         9
       dolphin       1.00      0.89      0.94         9
        donkey       0.88      0.78      0.82         9
     dragonfly       1.00      0.89      0.94         9
          duck       1.00      1.00      1.00         9
         eagle       0.90      1.00      0.95         9
      elephant       1.00      0.89      0.94         9
      flamingo       1.00      1.00      1.00         9
           fly       1.00      0.89      0.94         9
           fox       1.00      1.00      1.00         9
          goat       0.78      0.78      0.78         9
      goldfish       1.00      1.00      1.00         9
         goose       1.00      1.00      1.00         9
       gorilla       1.00      1.00      1.00         9
   grasshopper       1.00      1.00      1.00         9
       hamster       0.90      1.00      0.95         9
          hare       1.00      0.89      0.94         9
      hedgehog       0.90      1.00      0.95         9
  hippopotamus       1.00      1.00      1.00         9
      hornbill       1.00      0.89      0.94         9
         horse       0.80      0.89      0.84         9
   hummingbird       1.00      1.00      1.00         9
         hyena       1.00      1.00      1.00         9
     jellyfish       0.90      1.00      0.95         9
      kangaroo       1.00      1.00      1.00         9
         koala       1.00      1.00      1.00         9
      ladybugs       0.90      1.00      0.95         9
       leopard       1.00      1.00      1.00         9
          lion       1.00      1.00      1.00         9
        lizard       1.00      1.00      1.00         9
       lobster       1.00      1.00      1.00         9
      mosquito       0.90      1.00      0.95         9
          moth       0.88      0.78      0.82         9
         mouse       0.62      0.56      0.59         9
       octopus       0.73      0.89      0.80         9
         okapi       1.00      1.00      1.00         9
     orangutan       1.00      1.00      1.00         9
         otter       1.00      1.00      1.00         9
           owl       1.00      0.89      0.94         9
            ox       1.00      0.89      0.94         9
        oyster       1.00      1.00      1.00         9
         panda       1.00      1.00      1.00         9
        parrot       1.00      1.00      1.00         9
pelecaniformes       1.00      1.00      1.00         9
       penguin       1.00      1.00      1.00         9
           pig       1.00      1.00      1.00         9
        pigeon       1.00      1.00      1.00         9
     porcupine       0.89      0.89      0.89         9
        possum       0.89      0.89      0.89         9
       raccoon       1.00      1.00      1.00         9
           rat       0.71      0.56      0.62         9
      reindeer       0.90      1.00      0.95         9
    rhinoceros       0.82      1.00      0.90         9
     sandpiper       1.00      1.00      1.00         9
      seahorse       0.90      1.00      0.95         9
          seal       1.00      1.00      1.00         9
         shark       0.90      1.00      0.95         9
         sheep       1.00      0.89      0.94         9
         snake       1.00      1.00      1.00         9
       sparrow       1.00      1.00      1.00         9
         squid       1.00      0.89      0.94         9
      squirrel       0.69      1.00      0.82         9
      starfish       1.00      0.89      0.94         9
          swan       1.00      1.00      1.00         9
         tiger       1.00      0.89      0.94         9
        turkey       1.00      1.00      1.00         9
        turtle       1.00      1.00      1.00         9
         whale       1.00      1.00      1.00         9
          wolf       0.89      0.89      0.89         9
        wombat       1.00      1.00      1.00         9
    woodpecker       1.00      1.00      1.00         9
         zebra       1.00      1.00      1.00         9

      accuracy                           0.95       810
     macro avg       0.95      0.95      0.95       810
  weighted avg       0.95      0.95      0.95       810

Best / Worst Predictions¶
In [25]:
class_names =  list(test_enB3.class_indices.keys())

best_5_enB3, worst_5_enB3 = get_best_worst(
    y_true_enB3, y_pred_enB3, class_names
)
Best 5 classes (F1 Score):
['turtle', 'whale', 'wombat', 'woodpecker', 'zebra']

Worst 5 classes (F1 Score):
['mouse', 'rat', 'goat', 'octopus', 'squirrel']
Confusion Matrix¶
In [26]:
plot_confusion_for_classes(
    y_true_enB3, y_pred_enB3, best_5_enB3, worst_5_enB3,
    class_names, model_name = "efficientnet"
)
No description has been provided for this image
Wrong Predictions¶
In [27]:
# Get wrong predictions
wrong_indices_enB3 = np.where(y_true_enB3 != y_pred_enB3)[0]
print(f"Found {len(wrong_indices_enB3)} misclassified images.")

# plot
show_wrong_predictions(
    y_true_enB3, y_pred_enB3, wrong_indices_enB3,
    test_dir, model_name = "efficientnet"
)
Found 40 misclassified images.
No description has been provided for this image
Save Predictions¶
In [28]:
# Save Predictions
df_preds_enB3 = pd.DataFrame({
    'filename': [str(os.path.basename(p)) for p in test_enB3.filenames],
    'true_label': [class_names[i] for i in y_true_enB3],
    'predicted_label': [class_names[i] for i in y_pred_enB3],
    'correct': (y_true_enB3 == y_pred_enB3)
})

df_preds_enB3.to_csv(f"{results_dir}/predictions_EfficientNetB3.csv", index=False)
In [29]:
train_dir
Out[29]:
'./cis730_term_project/data/train'

2.6. Model 3 - ResNet50¶

2.6.0. Load Datasets¶

In [30]:
# Load datasets
train_rn50 = load_image_dataset(train_dir, model_name ="resnet")
val_rn50 = load_image_dataset(val_dir, model_name = "resnet")
test_rn50 = load_image_dataset(test_dir, shuffle=False, model_name = "resnet")
Found 3780 images belonging to 90 classes.
Found 810 images belonging to 90 classes.
Found 810 images belonging to 90 classes.

2.6.1. Build / Train ResNet50 Model¶

Model Reference: [https://keras.io/api/applications/resnet/#resnet50-function)

Source: https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5

In [31]:
compile_params.update(optimizer=tf.keras.optimizers.Adam(0.0005))

# Load model
model_rn50 = build_model(
    "resnet",
    num_classes=train_rn50.num_classes,
    compile_params=compile_params
)

early_stop_rn50 = get_early_stopping_callback()
model_check_point_rn50 = get_check_point_callback(
    checkpoint_path=f"{checkpoint_dir}/ResNet50CheckPoint.keras"
)

# Train model
history_rn50 = model_rn50.fit(
    train_rn50,
    validation_data=val_rn50,
    epochs=10,
    callbacks=[early_stop_rn50, model_check_point_rn50]
)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94765736/94765736 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
/usr/local/lib/python3.11/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:121: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.
  self._warn_if_super_not_called()
Epoch 1/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 228ms/step - accuracy: 0.2252 - loss: 3.6815
Epoch 1: val_loss improved from inf to 1.17967, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 53s 344ms/step - accuracy: 0.2269 - loss: 3.6725 - val_accuracy: 0.7481 - val_loss: 1.1797
Epoch 2/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 199ms/step - accuracy: 0.8646 - loss: 0.7711
Epoch 2: val_loss improved from 1.17967 to 0.72392, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 66s 250ms/step - accuracy: 0.8647 - loss: 0.7703 - val_accuracy: 0.8395 - val_loss: 0.7239
Epoch 3/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 201ms/step - accuracy: 0.9500 - loss: 0.3734
Epoch 3: val_loss improved from 0.72392 to 0.57050, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 31s 256ms/step - accuracy: 0.9500 - loss: 0.3732 - val_accuracy: 0.8753 - val_loss: 0.5705
Epoch 4/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 194ms/step - accuracy: 0.9832 - loss: 0.2141
Epoch 4: val_loss improved from 0.57050 to 0.49737, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 29s 244ms/step - accuracy: 0.9832 - loss: 0.2141 - val_accuracy: 0.8815 - val_loss: 0.4974
Epoch 5/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 195ms/step - accuracy: 0.9913 - loss: 0.1458
Epoch 5: val_loss improved from 0.49737 to 0.46535, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 41s 248ms/step - accuracy: 0.9914 - loss: 0.1458 - val_accuracy: 0.8864 - val_loss: 0.4654
Epoch 6/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 192ms/step - accuracy: 0.9959 - loss: 0.1051
Epoch 6: val_loss improved from 0.46535 to 0.42466, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 34s 287ms/step - accuracy: 0.9959 - loss: 0.1051 - val_accuracy: 0.8938 - val_loss: 0.4247
Epoch 7/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 200ms/step - accuracy: 0.9965 - loss: 0.0768
Epoch 7: val_loss improved from 0.42466 to 0.40783, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 37s 250ms/step - accuracy: 0.9965 - loss: 0.0768 - val_accuracy: 0.8951 - val_loss: 0.4078
Epoch 8/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 196ms/step - accuracy: 0.9981 - loss: 0.0578
Epoch 8: val_loss improved from 0.40783 to 0.39584, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 30s 251ms/step - accuracy: 0.9981 - loss: 0.0578 - val_accuracy: 0.8951 - val_loss: 0.3958
Epoch 9/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 196ms/step - accuracy: 0.9986 - loss: 0.0470
Epoch 9: val_loss improved from 0.39584 to 0.38877, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 30s 252ms/step - accuracy: 0.9986 - loss: 0.0470 - val_accuracy: 0.9012 - val_loss: 0.3888
Epoch 10/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 198ms/step - accuracy: 0.9991 - loss: 0.0386
Epoch 10: val_loss improved from 0.38877 to 0.38194, saving model to ./cis730_term_project/checkpoints/ResNet50CheckPoint.keras
119/119 ━━━━━━━━━━━━━━━━━━━━ 35s 293ms/step - accuracy: 0.9991 - loss: 0.0386 - val_accuracy: 0.8988 - val_loss: 0.3819
Restoring model weights from the end of the best epoch: 10.
In [32]:
# Plot loss/accuracy
plot_training_loss_accuracy(history_rn50, model_name="resnet")
No description has been provided for this image

2.6.2. Fine Tune ResNet50 Model¶

In [33]:
model_rn50.trainable = True
model_rn50.summary(show_trainable=True)

model_rn50.compile(**compile_params)

history_rn50 = model_rn50.fit(
    train_rn50,
    validation_data=val_rn50,
    epochs=10,
    callbacks=[early_stop, model_check_point]
)
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃ Layer (type)                ┃ Output Shape          ┃    Param # ┃ Trai… ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│ resnet50 (Functional)       │ (None, 7, 7, 2048)    │ 23,587,712 │   N   │
├─────────────────────────────┼───────────────────────┼────────────┼───────┤
│ global_average_pooling2d_2  │ (None, 2048)          │          0 │   -   │
│ (GlobalAveragePooling2D)    │                       │            │       │
├─────────────────────────────┼───────────────────────┼────────────┼───────┤
│ dense_2 (Dense)             │ (None, 90)            │    184,410 │   Y   │
└─────────────────────────────┴───────────────────────┴────────────┴───────┘
 Total params: 24,140,944 (92.09 MB)
 Trainable params: 184,410 (720.35 KB)
 Non-trainable params: 23,587,712 (89.98 MB)
 Optimizer params: 368,822 (1.41 MB)
Epoch 1/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 215ms/step - accuracy: 0.9993 - loss: 0.0331
Epoch 1: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 47s 313ms/step - accuracy: 0.9993 - loss: 0.0331 - val_accuracy: 0.9037 - val_loss: 0.3729
Epoch 2/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 197ms/step - accuracy: 1.0000 - loss: 0.0281
Epoch 2: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 29s 242ms/step - accuracy: 0.9999 - loss: 0.0281 - val_accuracy: 0.8988 - val_loss: 0.3643
Epoch 3/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 194ms/step - accuracy: 0.9998 - loss: 0.0233
Epoch 3: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 34s 282ms/step - accuracy: 0.9998 - loss: 0.0233 - val_accuracy: 0.8938 - val_loss: 0.3638
Epoch 4/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 192ms/step - accuracy: 0.9999 - loss: 0.0206
Epoch 4: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 28s 238ms/step - accuracy: 0.9999 - loss: 0.0206 - val_accuracy: 0.8988 - val_loss: 0.3603
Epoch 5/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 197ms/step - accuracy: 0.9999 - loss: 0.0188
Epoch 5: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 29s 247ms/step - accuracy: 0.9999 - loss: 0.0188 - val_accuracy: 0.8951 - val_loss: 0.3573
Epoch 6/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 197ms/step - accuracy: 0.9999 - loss: 0.0164
Epoch 6: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 29s 247ms/step - accuracy: 0.9999 - loss: 0.0164 - val_accuracy: 0.8988 - val_loss: 0.3527
Epoch 7/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 197ms/step - accuracy: 0.9999 - loss: 0.0138
Epoch 7: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 29s 242ms/step - accuracy: 0.9999 - loss: 0.0138 - val_accuracy: 0.9000 - val_loss: 0.3520
Epoch 8/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 197ms/step - accuracy: 0.9996 - loss: 0.0143
Epoch 8: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 29s 247ms/step - accuracy: 0.9996 - loss: 0.0143 - val_accuracy: 0.9037 - val_loss: 0.3483
Epoch 9/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 194ms/step - accuracy: 1.0000 - loss: 0.0113
Epoch 9: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 28s 239ms/step - accuracy: 1.0000 - loss: 0.0113 - val_accuracy: 0.9012 - val_loss: 0.3471
Epoch 10/10
119/119 ━━━━━━━━━━━━━━━━━━━━ 0s 197ms/step - accuracy: 0.9992 - loss: 0.0125
Epoch 10: val_loss did not improve from 0.18304
119/119 ━━━━━━━━━━━━━━━━━━━━ 29s 240ms/step - accuracy: 0.9992 - loss: 0.0125 - val_accuracy: 0.9012 - val_loss: 0.3488
Restoring model weights from the end of the best epoch: 9.

2.6.3. Accuracy Report¶

Get Predictions¶
In [34]:
# Get Predictions
y_true_rn50, y_pred_rn50 = get_predictions(test_rn50, model_rn50)
26/26 ━━━━━━━━━━━━━━━━━━━━ 13s 327ms/step
Classification Report¶
In [35]:
print(classification_report(
    y_true_rn50, y_pred_rn50, target_names=test_rn50.class_indices.keys()
))
                precision    recall  f1-score   support

      antelope       0.73      0.89      0.80         9
        badger       1.00      1.00      1.00         9
           bat       0.75      0.67      0.71         9
          bear       1.00      1.00      1.00         9
           bee       0.90      1.00      0.95         9
        beetle       1.00      1.00      1.00         9
         bison       0.90      1.00      0.95         9
          boar       0.82      1.00      0.90         9
     butterfly       1.00      1.00      1.00         9
           cat       0.90      1.00      0.95         9
   caterpillar       0.73      0.89      0.80         9
    chimpanzee       1.00      1.00      1.00         9
     cockroach       1.00      1.00      1.00         9
           cow       1.00      1.00      1.00         9
        coyote       0.90      1.00      0.95         9
          crab       1.00      0.89      0.94         9
          crow       0.82      1.00      0.90         9
          deer       0.67      0.89      0.76         9
           dog       1.00      0.89      0.94         9
       dolphin       1.00      0.89      0.94         9
        donkey       0.78      0.78      0.78         9
     dragonfly       0.88      0.78      0.82         9
          duck       1.00      0.67      0.80         9
         eagle       1.00      0.89      0.94         9
      elephant       1.00      0.78      0.88         9
      flamingo       1.00      1.00      1.00         9
           fly       0.78      0.78      0.78         9
           fox       1.00      0.89      0.94         9
          goat       0.78      0.78      0.78         9
      goldfish       1.00      1.00      1.00         9
         goose       0.90      1.00      0.95         9
       gorilla       1.00      0.89      0.94         9
   grasshopper       1.00      0.89      0.94         9
       hamster       0.82      1.00      0.90         9
          hare       0.78      0.78      0.78         9
      hedgehog       1.00      1.00      1.00         9
  hippopotamus       0.89      0.89      0.89         9
      hornbill       1.00      1.00      1.00         9
         horse       0.78      0.78      0.78         9
   hummingbird       1.00      1.00      1.00         9
         hyena       1.00      0.89      0.94         9
     jellyfish       1.00      1.00      1.00         9
      kangaroo       0.90      1.00      0.95         9
         koala       0.89      0.89      0.89         9
      ladybugs       1.00      0.89      0.94         9
       leopard       0.89      0.89      0.89         9
          lion       0.89      0.89      0.89         9
        lizard       1.00      0.67      0.80         9
       lobster       0.89      0.89      0.89         9
      mosquito       0.80      0.89      0.84         9
          moth       0.70      0.78      0.74         9
         mouse       0.57      0.44      0.50         9
       octopus       0.69      1.00      0.82         9
         okapi       0.90      1.00      0.95         9
     orangutan       1.00      1.00      1.00         9
         otter       0.90      1.00      0.95         9
           owl       0.75      0.67      0.71         9
            ox       1.00      0.67      0.80         9
        oyster       1.00      1.00      1.00         9
         panda       1.00      1.00      1.00         9
        parrot       1.00      0.89      0.94         9
pelecaniformes       1.00      1.00      1.00         9
       penguin       1.00      0.89      0.94         9
           pig       0.90      1.00      0.95         9
        pigeon       1.00      1.00      1.00         9
     porcupine       1.00      1.00      1.00         9
        possum       0.88      0.78      0.82         9
       raccoon       0.78      0.78      0.78         9
           rat       0.70      0.78      0.74         9
      reindeer       0.70      0.78      0.74         9
    rhinoceros       0.90      1.00      0.95         9
     sandpiper       1.00      0.89      0.94         9
      seahorse       1.00      1.00      1.00         9
          seal       0.78      0.78      0.78         9
         shark       1.00      1.00      1.00         9
         sheep       0.80      0.89      0.84         9
         snake       0.88      0.78      0.82         9
       sparrow       0.90      1.00      0.95         9
         squid       0.88      0.78      0.82         9
      squirrel       1.00      0.78      0.88         9
      starfish       0.80      0.89      0.84         9
          swan       1.00      1.00      1.00         9
         tiger       1.00      0.78      0.88         9
        turkey       1.00      1.00      1.00         9
        turtle       1.00      0.89      0.94         9
         whale       0.90      1.00      0.95         9
          wolf       1.00      0.89      0.94         9
        wombat       0.90      1.00      0.95         9
    woodpecker       0.90      1.00      0.95         9
         zebra       1.00      1.00      1.00         9

      accuracy                           0.90       810
     macro avg       0.91      0.90      0.90       810
  weighted avg       0.91      0.90      0.90       810

Best / Worst Predictions¶
In [36]:
class_names =  list(test_rn50.class_indices.keys())

best_5_rn50, worst_5_rn50 = get_best_worst(
    y_true_rn50, y_pred_rn50, class_names
)
Best 5 classes (F1 Score):
['shark', 'pigeon', 'turkey', 'swan', 'zebra']

Worst 5 classes (F1 Score):
['mouse', 'bat', 'owl', 'moth', 'rat']
Confusion Matrix¶
In [37]:
plot_confusion_for_classes(
    y_true_rn50, y_pred_rn50, best_5_rn50, worst_5_rn50,
    class_names, model_name = "resnet"
)
No description has been provided for this image
Wrong Predictions¶
In [38]:
# Get wrong predictions
wrong_indices_rn50 = np.where(y_true_rn50 != y_pred_rn50)[0]
print(f"Found {len(wrong_indices_rn50)} misclassified images.")

# plot
show_wrong_predictions(
    y_true_rn50, y_pred_rn50, wrong_indices_rn50,
    test_dir, model_name = "resnet"
)
Found 79 misclassified images.
No description has been provided for this image
Save Predictions¶
In [39]:
# Save Predictions
df_preds_rn50 = pd.DataFrame({
    'filename': [str(os.path.basename(p)) for p in test_rn50.filenames],
    'true_label': [class_names[i] for i in y_true_rn50],
    'predicted_label': [class_names[i] for i in y_pred_rn50],
    'correct': (y_true_rn50 == y_pred_rn50)
})

df_preds_rn50.to_csv(f"{results_dir}/predictions_ResNet50.csv", index=False)

3. Modelling with Data Augmentations¶

3.0. Function to run all 3 models with augmentation¶

In [40]:
image_dir_aug = "./cis730_term_project/images/data_augmented"
os.makedirs(image_dir_aug, exist_ok=True)

def run_augmentation_processes(
        train_dir=train_dir,
        val_dir=val_dir,
        test_dir=test_dir,
        model_name="mobilenet",
        keras_model="MobileNetV2",
        index=0,
        compile_params=compile_params
    ):

    print(f'=================================================')
    print()
    print(f'3.{index+1}. MODEL 1 - {keras_model}')
    print()
    print(f'=================================================')
    print()

    # Load datasets
    print(f'3.{index+1}.0. LOAD DATASETS \n PERFORM IMAGE AUGMENTATIONS ON TRAINING SET.**')
    print()
    if model_name == "efficientnet":
        train_ds = load_image_dataset(
            train_dir, image_size=(300, 300), model_name=model_name, augment=True
        )
        val_ds = load_image_dataset(val_dir, image_size=(300, 300), model_name=model_name)
        test_ds = load_image_dataset(
            test_dir, image_size=(300, 300), model_name=model_name, shuffle=False
            )
    else:
        train_ds = load_image_dataset(train_dir, model_name=model_name, augment=True)
        val_ds = load_image_dataset(val_dir, model_name=model_name)
        test_ds = load_image_dataset(test_dir, model_name=model_name, shuffle=False)
    print()
    print("-------------------------------------------------")
    print()

    # Train model
    print(f'3.{index+1}.1. BUILD / TRAIN {keras_model} MODEL.')
    print()
    compile_params.update(optimizer=tf.keras.optimizers.Adam(0.0005))

    model = build_model(
        model_name,
        num_classes=train_ds.num_classes,
        compile_params=compile_params
    )

    early_stop = get_early_stopping_callback()
    model_check_point = get_check_point_callback(
        checkpoint_path=f"{checkpoint_dir}/{keras_model}CheckPoint_data_augment.keras"
    )

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=10,
        callbacks=[early_stop, model_check_point]
    )
    print()
    print("-------------------------------------------------")
    print()

    # Plot loss/accuracy
    print(f'**Training loss/accuracy for {keras_model} model.')
    print()
    plot_training_loss_accuracy(
        history, model_name=model_name, output_dir=image_dir_aug
    )
    print()
    print("-------------------------------------------------")
    print()

    # Fine Tune
    print(f'3.{index+1}.2. FINE TUNE {keras_model} MODEL.')
    print()
    model.trainable = True
    model.summary(show_trainable=True)
    print()
    print("-------------------------------------------------")
    print()

    model.compile(**compile_params)

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=10,
        callbacks=[early_stop, model_check_point]
    )
    print()
    print("-------------------------------------------------")
    print()

    print(f'3.{index+1}.3. for {keras_model}.**')
    print()
    # Get Predictions
    print(f'**Get Predictions for {keras_model}.**')
    print()
    y_true, y_pred = get_predictions(test_ds, model)

    # Classification
    print(f'**Classification Report for {keras_model}.**')
    print()
    print(classification_report(
        y_true, y_pred, target_names=test_ds.class_indices.keys()
    ))
    print()
    print("-------------------------------------------------")
    print()

    # Get Best and Worst
    print(f'**Best and Worst for {keras_model}.**')
    print()
    class_names =  list(test_ds.class_indices.keys())

    best_5, worst_5 = get_best_worst(
        y_true, y_pred, class_names
    )
    print()
    print("-------------------------------------------------")
    print()

    # Confusion
    print(f'**Confusion Matrix for {keras_model}.**')
    print()
    plot_confusion_for_classes(
        y_true, y_pred, best_5, worst_5, class_names, model_name=model_name,
        output_dir=image_dir_aug
    )
    print()
    print("-------------------------------------------------")
    print()

    print(f'**Wrong predictions for {keras_model}.**')
    print()
    # Get wrong predictions
    wrong_indices = np.where(y_true != y_pred)[0]
    print(f"Found {len(wrong_indices)} misclassified images.")
    print()
    print("-------------------------------------------------")
    print()

    # Show some wrong predictions
    show_wrong_predictions(
        y_true, y_pred, wrong_indices, test_dir, model_name=model_name,
        output_dir=image_dir_aug
    )
    print()
    print("-------------------------------------------------")
    print()

    # Save Predictions
    print(f'**Save Predictions for {keras_model}.**')
    print()
    df_preds = pd.DataFrame({
        'filename': [str(os.path.basename(p)) for p in test_ds.filenames],
        'true_label': [class_names[i] for i in y_true],
        'predicted_label': [class_names[i] for i in y_pred],
        'correct': (y_true == y_pred)
    })

    df_preds.to_csv(f"{results_dir}/predictions_{keras_model}_data_augment.csv", index=False)
    print()
    print("-------------------------------------------------")
    print("-------------------------------------------------")
    print()
    print()
    print()
    print()

Run with augmented data¶

In [41]:
model_names = {
    "mobilenet": "MobileNetV2",
    "efficientnet": "EfficientNetB3",
    "resnet": "ResNet50"
}

for index, (model_name, keras_model) in enumerate(model_names.items()):
    print(index+1, model_name, keras_model)
    run_augmentation_processes(
        model_name=model_name,
        keras_model=keras_model,
        index=index
    )
Output hidden; open in https://colab.research.google.com to view.